{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChainerRL Quickstart Guide\n",
    "\n",
    "This is a quickstart guide for users who just want to try ChainerRL for the first time.\n",
    "\n",
    "If you have not yet installed ChainerRL, run the command below to install it:\n",
    "```\n",
    "pip install chainerrl\n",
    "```\n",
    "\n",
    "If you have already installed ChainerRL, let's begin!\n",
    "\n",
    "First, import the necessary modules. The module name of ChainerRL is `chainerrl`. Let's import `gym` and `numpy` as well, since they are used later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import chainer\n",
    "import chainer.functions as F\n",
    "import chainer.links as L\n",
    "import chainerrl\n",
    "import gym\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ChainerRL can be applied to any problem that is modeled as an \"environment\". [OpenAI Gym](https://github.com/openai/gym) provides a variety of benchmark environments and defines a common interface for them. ChainerRL uses a subset of that interface. Specifically, an environment must define its observation space and action space and have at least two methods: `reset` and `step`. A third method, `render`, is optional and is only used for visualization.\n",
    "\n",
    "- `env.reset` resets the environment to the initial state and returns the initial observation.\n",
    "- `env.step` executes a given action, moves to the next state and returns four values:\n",
    "  - the next observation\n",
    "  - a scalar reward\n",
    "  - a boolean value indicating whether the current state is terminal or not\n",
    "  - additional information\n",
    "- `env.render` renders the current state.\n",
    "\n",
    "Let's try 'CartPole-v0', which is a classic control problem. You can see below that its observation space consists of four real numbers while its action space consists of two discrete actions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-10-09 17:35:11,952] Making new env: CartPole-v0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "observation space: Box(4,)\n",
      "action space: Discrete(2)\n",
      "initial observation: [-0.01102502  0.03448551  0.01958173  0.01880113]\n",
      "next observation: [-0.01033531 -0.16091171  0.01995775  0.31759744]\n",
      "reward: 1.0\n",
      "done: False\n",
      "info: {}\n"
     ]
    }
   ],
   "source": [
    "env = gym.make('CartPole-v0')\n",
    "print('observation space:', env.observation_space)\n",
    "print('action space:', env.action_space)\n",
    "\n",
    "obs = env.reset()\n",
    "env.render()\n",
    "print('initial observation:', obs)\n",
    "\n",
    "action = env.action_space.sample()\n",
    "obs, r, done, info = env.step(action)\n",
    "print('next observation:', obs)\n",
    "print('reward:', r)\n",
    "print('done:', done)\n",
    "print('info:', info)"
   ]
  },
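  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interface described above can also be sketched without Gym. The example below is a minimal, hypothetical illustration: `CountdownEnv` is not part of Gym or ChainerRL, and it omits the `observation_space` and `action_space` attributes that a real environment would also define. It only shows the contract of `reset` and `step`.\n",
    "\n",
    "```python\n",
    "class CountdownEnv:\n",
    "    '''Hypothetical toy environment that follows the reset/step contract.\n",
    "\n",
    "    The episode lasts a fixed number of steps. Action 1 earns a reward\n",
    "    of 1.0 and action 0 earns 0.0.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, episode_len=3):\n",
    "        self.episode_len = episode_len\n",
    "        self.remaining = episode_len\n",
    "\n",
    "    def reset(self):\n",
    "        # Go back to the initial state and return the initial observation.\n",
    "        self.remaining = self.episode_len\n",
    "        return [float(self.remaining)]\n",
    "\n",
    "    def step(self, action):\n",
    "        # Apply the action, advance one step, and return the four values.\n",
    "        self.remaining -= 1\n",
    "        next_obs = [float(self.remaining)]\n",
    "        reward = 1.0 if action == 1 else 0.0\n",
    "        done = self.remaining == 0\n",
    "        info = {}\n",
    "        return next_obs, reward, done, info\n",
    "\n",
    "\n",
    "toy_env = CountdownEnv(episode_len=3)\n",
    "toy_obs = toy_env.reset()\n",
    "toy_done = False\n",
    "toy_return = 0.0\n",
    "while not toy_done:\n",
    "    toy_obs, toy_reward, toy_done, toy_info = toy_env.step(1)\n",
    "    toy_return += toy_reward\n",
    "print('final observation:', toy_obs, 'return:', toy_return)\n",
    "```\n",
    "\n",
    "The loop has the same shape as the interaction with CartPole-v0 above: call `reset` once, then call `step` repeatedly until `done` is true."
   ]
  },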
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you have defined your environment. Next, you need to define an agent, which will learn through interactions with the environment.\n",
    "\n",
    "ChainerRL provides various agents, each of which implements a deep reinforcement learning algorithm.\n",
    "\n",
    "To use [DQN (Deep Q-Network)](https://doi.org/10.1038/nature14236), you need to define a Q-function that receives an observation and returns the expected future return for each action the agent can take. In ChainerRL, you can define your Q-function as a `chainer.Link`, as shown below. Note that the outputs are wrapped by `chainerrl.action_value.DiscreteActionValue`, which implements `chainerrl.action_value.ActionValue`. Because the outputs of Q-functions are wrapped this way, ChainerRL can treat discrete-action Q-functions like this one and [NAFs (Normalized Advantage Functions)](https://arxiv.org/abs/1603.00748) in the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class QFunction(chainer.Chain):\n",
    "\n",
    "    def __init__(self, obs_size, n_actions, n_hidden_channels=50):\n",
    "        super().__init__()\n",
    "        with self.init_scope():\n",
    "            self.l0 = L.Linear(obs_size, n_hidden_channels)\n",
    "            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)\n",
    "            self.l2 = L.Linear(n_hidden_channels, n_actions)\n",
    "\n",
    "    def __call__(self, x, test=False):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x (ndarray or chainer.Variable): An observation\n",
    "            test (bool): a flag indicating whether it is in test mode\n",
    "        \"\"\"\n",
    "        h = F.tanh(self.l0(x))\n",
    "        h = F.tanh(self.l1(h))\n",
    "        return chainerrl.action_value.DiscreteActionValue(self.l2(h))\n",
    "\n",
    "obs_size = env.observation_space.shape[0]\n",
    "n_actions = env.action_space.n\n",
    "q_func = QFunction(obs_size, n_actions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use CUDA for computation, call `to_gpu`, as usual in Chainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Uncomment to use CUDA\n",
    "# q_func.to_gpu(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use ChainerRL's predefined Q-functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "_q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(\n",
    "    obs_size, n_actions,\n",
    "    n_hidden_layers=2, n_hidden_channels=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As in Chainer, `chainer.Optimizer` is used to update models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Use Adam to optimize q_func. eps=1e-2 is for stability.\n",
    "optimizer = chainer.optimizers.Adam(eps=1e-2)\n",
    "optimizer.setup(q_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Q-function and its optimizer are used by a DQN agent. To create a DQN agent, you need to specify a few more parameters and settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Set the discount factor that discounts future rewards.\n",
    "gamma = 0.95\n",
    "\n",
    "# Use epsilon-greedy for exploration\n",
    "explorer = chainerrl.explorers.ConstantEpsilonGreedy(\n",
    "    epsilon=0.3, random_action_func=env.action_space.sample)\n",
    "\n",
    "# DQN uses Experience Replay.\n",
    "# Specify a replay buffer and its capacity.\n",
    "replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)\n",
    "\n",
    "# Since observations from CartPole-v0 are numpy.float64 while\n",
    "# Chainer only accepts numpy.float32 by default, specify\n",
    "# a converter as a feature extractor function phi.\n",
    "phi = lambda x: x.astype(np.float32, copy=False)\n",
    "\n",
    "# Now create an agent that will interact with the environment.\n",
    "# Note that DoubleDQN, a variant of DQN, is used here.\n",
    "agent = chainerrl.agents.DoubleDQN(\n",
    "    q_func, optimizer, replay_buffer, gamma, explorer,\n",
    "    replay_start_size=500, update_interval=1,\n",
    "    target_update_interval=100, phi=phi)"
   ]
  },
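  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two of the settings above, the discount factor `gamma` and the epsilon-greedy explorer, can be illustrated in plain Python. This is a minimal sketch and not ChainerRL's implementation; the helper names `discounted_return` and `epsilon_greedy` are hypothetical.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def discounted_return(rewards, gamma):\n",
    "    # Sum of rewards, where the reward at step t is weighted by gamma ** t.\n",
    "    return sum(r * gamma ** t for t, r in enumerate(rewards))\n",
    "\n",
    "\n",
    "def epsilon_greedy(q_values, epsilon, rng=random):\n",
    "    # With probability epsilon, pick a uniformly random action;\n",
    "    # otherwise, pick the action with the largest Q-value.\n",
    "    if rng.random() < epsilon:\n",
    "        return rng.randrange(len(q_values))\n",
    "    return max(range(len(q_values)), key=lambda a: q_values[a])\n",
    "\n",
    "\n",
    "# Three rewards of 1.0 with gamma = 0.5 give 1.0 + 0.5 + 0.25.\n",
    "print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))\n",
    "# With epsilon = 0.0 the choice is always greedy.\n",
    "print(epsilon_greedy([0.1, 0.9], epsilon=0.0))\n",
    "```\n",
    "\n",
    "With `gamma = 0.95` as above, a reward received 10 steps in the future is weighted by `0.95 ** 10`, which is roughly 0.6. With `epsilon=0.3`, about 30% of the actions taken during training are random."
   ]
  },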
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you have an agent and an environment. It's time to start reinforcement learning!\n",
    "\n",
    "In training, use `agent.act_and_train` to select exploratory actions. `agent.stop_episode_and_train` must be called after finishing an episode. You can get training statistics of the agent via `agent.get_statistics`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 10 R: 10.0 statistics: [('average_q', 0.033504104378229274), ('average_loss', 0)]\n",
      "episode: 20 R: 11.0 statistics: [('average_q', 0.05938679797261194), ('average_loss', 0)]\n",
      "episode: 30 R: 12.0 statistics: [('average_q', 0.08880156670371239), ('average_loss', 0)]\n",
      "episode: 40 R: 10.0 statistics: [('average_q', 0.10745671682940243), ('average_loss', 0)]\n",
      "episode: 50 R: 11.0 statistics: [('average_q', 0.1605578976152919), ('average_loss', 0.16413010902532935)]\n",
      "episode: 60 R: 12.0 statistics: [('average_q', 0.28350504480571603), ('average_loss', 0.32596588074553107)]\n",
      "episode: 70 R: 10.0 statistics: [('average_q', 0.4630023982307997), ('average_loss', 0.31781499766633914)]\n",
      "episode: 80 R: 10.0 statistics: [('average_q', 0.8794745899122342), ('average_loss', 0.1974944465810803)]\n",
      "episode: 90 R: 17.0 statistics: [('average_q', 1.4185744255580557), ('average_loss', 0.19480606991295596)]\n",
      "episode: 100 R: 43.0 statistics: [('average_q', 3.127327862835677), ('average_loss', 0.30055178911896646)]\n",
      "episode: 110 R: 111.0 statistics: [('average_q', 8.550209231394344), ('average_loss', 0.3159818577495354)]\n",
      "episode: 120 R: 200.0 statistics: [('average_q', 14.906816557700848), ('average_loss', 0.28902660652099954)]\n",
      "episode: 130 R: 200.0 statistics: [('average_q', 18.18435279855357), ('average_loss', 0.24366831958342194)]\n",
      "episode: 140 R: 200.0 statistics: [('average_q', 19.64378118841905), ('average_loss', 0.2305001637897665)]\n",
      "episode: 150 R: 200.0 statistics: [('average_q', 20.28042925179784), ('average_loss', 0.16466828120322005)]\n",
      "episode: 160 R: 200.0 statistics: [('average_q', 20.525837934804745), ('average_loss', 0.14458694507747866)]\n",
      "episode: 170 R: 200.0 statistics: [('average_q', 20.48710300430637), ('average_loss', 0.1437242373458598)]\n",
      "episode: 180 R: 200.0 statistics: [('average_q', 20.363789095886585), ('average_loss', 0.15989901066010842)]\n",
      "episode: 190 R: 200.0 statistics: [('average_q', 20.279585415986432), ('average_loss', 0.19724599736355053)]\n",
      "episode: 200 R: 102.0 statistics: [('average_q', 20.325483683521487), ('average_loss', 0.13446687501404633)]\n",
      "Finished.\n"
     ]
    }
   ],
   "source": [
    "n_episodes = 200\n",
    "max_episode_len = 200\n",
    "for i in range(1, n_episodes + 1):\n",
    "    obs = env.reset()\n",
    "    reward = 0\n",
    "    done = False\n",
    "    R = 0  # return (sum of rewards)\n",
    "    t = 0  # time step\n",
    "    while not done and t < max_episode_len:\n",
    "        # Uncomment to watch the behaviour\n",
    "        # env.render()\n",
    "        action = agent.act_and_train(obs, reward)\n",
    "        obs, reward, done, _ = env.step(action)\n",
    "        R += reward\n",
    "        t += 1\n",
    "    if i % 10 == 0:\n",
    "        print('episode:', i,\n",
    "              'R:', R,\n",
    "              'statistics:', agent.get_statistics())\n",
    "    agent.stop_episode_and_train(obs, reward, done)\n",
    "print('Finished.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Now you have finished training the agent. How good is it? You can test it by using `agent.act` and `agent.stop_episode` instead. Exploration such as epsilon-greedy is no longer used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test episode: 0 R: 200.0\n",
      "test episode: 1 R: 200.0\n",
      "test episode: 2 R: 200.0\n",
      "test episode: 3 R: 200.0\n",
      "test episode: 4 R: 200.0\n",
      "test episode: 5 R: 200.0\n",
      "test episode: 6 R: 200.0\n",
      "test episode: 7 R: 200.0\n",
      "test episode: 8 R: 200.0\n",
      "test episode: 9 R: 200.0\n"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    obs = env.reset()\n",
    "    done = False\n",
    "    R = 0\n",
    "    t = 0\n",
    "    while not done and t < 200:\n",
    "        env.render()\n",
    "        action = agent.act(obs)\n",
    "        obs, r, done, _ = env.step(action)\n",
    "        R += r\n",
    "        t += 1\n",
    "    print('test episode:', i, 'R:', R)\n",
    "    agent.stop_episode()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "If the test scores are good enough, the only remaining task is to save the agent so that you can reuse it. Simply call `agent.save` to save the agent, and later call `agent.load` to load the saved agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Save an agent to the 'agent' directory\n",
    "agent.save('agent')\n",
    "\n",
    "# Uncomment to load an agent from the 'agent' directory\n",
    "# agent.load('agent')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "RL completed!\n",
    "\n",
    "But writing code like this every time you use RL can be tedious, so ChainerRL provides utility functions that do these things for you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "outdir:result step:167 episode:0 R:167.0\n",
      "statistics:[('average_q', 20.337137631400005), ('average_loss', 0.14815392158495394)]\n",
      "outdir:result step:367 episode:1 R:200.0\n",
      "statistics:[('average_q', 20.32523375841225), ('average_loss', 0.14421337271314047)]\n",
      "outdir:result step:567 episode:2 R:200.0\n",
      "statistics:[('average_q', 20.322520705385344), ('average_loss', 0.11125931680483415)]\n",
      "outdir:result step:600 episode:3 R:33.0\n",
      "statistics:[('average_q', 20.318184376935776), ('average_loss', 0.13332111934756039)]\n",
      "outdir:result step:800 episode:4 R:200.0\n",
      "statistics:[('average_q', 20.35352763185174), ('average_loss', 0.12591677219894507)]\n",
      "outdir:result step:991 episode:5 R:191.0\n",
      "statistics:[('average_q', 20.344118637007757), ('average_loss', 0.14653855789577244)]\n",
      "outdir:result step:1037 episode:6 R:46.0\n",
      "statistics:[('average_q', 20.35218652983029), ('average_loss', 0.13151683572695283)]\n",
      "test episode: 0 R: 200.0\n",
      "test episode: 1 R: 200.0\n",
      "test episode: 2 R: 200.0\n",
      "test episode: 3 R: 200.0\n",
      "test episode: 4 R: 200.0\n",
      "test episode: 5 R: 200.0\n",
      "test episode: 6 R: 200.0\n",
      "test episode: 7 R: 200.0\n",
      "test episode: 8 R: 200.0\n",
      "test episode: 9 R: 200.0\n",
      "The best score is updated -3.40282e+38 -> 200.0\n",
      "Saved the agent to result/1037\n",
      "outdir:result step:1237 episode:7 R:200.0\n",
      "statistics:[('average_q', 20.307884243495028), ('average_loss', 0.13172055951951084)]\n",
      "outdir:result step:1430 episode:8 R:193.0\n",
      "statistics:[('average_q', 20.321500150392207), ('average_loss', 0.12309285619445343)]\n",
      "outdir:result step:1514 episode:9 R:84.0\n",
      "statistics:[('average_q', 20.32210402767951), ('average_loss', 0.1512905636217192)]\n",
      "outdir:result step:1579 episode:10 R:65.0\n",
      "statistics:[('average_q', 20.32823341815203), ('average_loss', 0.1405936792339856)]\n",
      "outdir:result step:1737 episode:11 R:158.0\n",
      "statistics:[('average_q', 20.303456988399578), ('average_loss', 0.125684738660057)]\n",
      "outdir:result step:1791 episode:12 R:54.0\n",
      "statistics:[('average_q', 20.29044738590696), ('average_loss', 0.1366198758930306)]\n",
      "outdir:result step:1991 episode:13 R:200.0\n",
      "statistics:[('average_q', 20.259623728542685), ('average_loss', 0.11972923205456666)]\n",
      "outdir:result step:2000 episode:14 R:9.0\n",
      "statistics:[('average_q', 20.26092540709551), ('average_loss', 0.1261842990180651)]\n",
      "test episode: 0 R: 200.0\n",
      "test episode: 1 R: 200.0\n",
      "test episode: 2 R: 200.0\n",
      "test episode: 3 R: 200.0\n",
      "test episode: 4 R: 200.0\n",
      "test episode: 5 R: 200.0\n",
      "test episode: 6 R: 200.0\n",
      "test episode: 7 R: 200.0\n",
      "test episode: 8 R: 200.0\n",
      "test episode: 9 R: 200.0\n",
      "Saved the agent to result/2000_finish\n"
     ]
    }
   ],
   "source": [
    "# Set up the logger to print info messages so that progress is visible.\n",
    "import logging\n",
    "import sys\n",
    "logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')\n",
    "\n",
    "chainerrl.experiments.train_agent_with_evaluation(\n",
    "    agent, env,\n",
    "    steps=2000,                 # Train the agent for 2000 steps\n",
    "    eval_n_steps=None,          # Evaluate by episodes, not by time steps\n",
    "    eval_n_episodes=10,         # Sample 10 episodes for each evaluation\n",
    "    train_max_episode_len=200,  # Maximum length of each episode\n",
    "    eval_interval=1000,         # Evaluate the agent after every 1000 steps\n",
    "    outdir='result')            # Save everything to the 'result' directory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "That's all for the ChainerRL quickstart guide. To learn more about ChainerRL, look through the `examples` directory, then read and run the examples. Thank you!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
